// Copyright 2018 Xiaomi, Inc.  All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// This is a generated file. DO NOT EDIT!

#include <memory>
#include <string>

#include "mace/core/macros.h"
#include "mace/proto/mace.pb.h"
#include "mace/public/mace.h"
#include "mace/utils/env_time.h"
#include "mace/utils/logging.h"

namespace mace {
namespace {{ tag }} {

// Forward declarations of the per-tensor and per-operator builder functions
// called further down in this file. They are only declared here; one
// declaration is emitted for every tensor and every operator of the model.
{% for tensor in net.tensors %}
extern void CreateTensor{{ loop.index0 }}(mace::ConstTensor *tensor);
{% endfor %}


{% for op_def in net.op %}
extern void CreateOperator{{ loop.index0 }}(mace::OperatorDef *op);
{% endfor %}

}  // namespace {{ tag }}

namespace {

{% if net.arg|length != 0 %}
// Appends one mace::Argument to net_def for every network-level argument of
// the model: its name, whichever of the optional scalar fields (f, i, s) are
// set, and its floats / ints lists. Only emitted when the model has at least
// one such argument.
void CreateNetArg(NetDef *net_def) {
  net_def->mutable_arg()->Reserve({{ net.arg|length }});
  {% for net_arg in net.arg %}

  {
    mace::Argument *entry = net_def->add_arg();
    entry->set_name({{ net_arg.name|tojson }});
    {% if net_arg.HasField('f') %}
    entry->set_f({{ net_arg.f }});
    {% endif %}
    {% if net_arg.HasField('i') %}
    entry->set_i({{ net_arg.i }});
    {% endif %}
    {% if net_arg.HasField('s') %}
    entry->set_s({{ net_arg.s|tojson }});
    {% endif %}

    entry->mutable_floats()->Reserve({{ net_arg.floats|length }});
    {% for value in net_arg.floats %}
    entry->add_floats({{ value }});
    {% endfor %}
    entry->mutable_ints()->Reserve({{ net_arg.ints|length }});
    {% for value in net_arg.ints %}
    entry->add_ints({{ value }});
    {% endfor %}
  }

  {% endfor %}
}
{% endif %}

{% if net.input_info | length > 0 %}
// Appends one InputInfo to net_def for every input description of the model,
// copying its name, data type and dimension list. Only emitted when the
// model declares at least one input.
void CreateInputInfo(NetDef *net_def) {
  net_def->mutable_input_info()->Reserve({{ net.input_info | length }});
  {% for info in net.input_info %}
  {
    InputInfo *entry = net_def->add_input_info();
    entry->set_name({{ info.name|tojson }});
    entry->set_data_type(static_cast<DataType>({{ info.data_type }}));
    entry->mutable_dims()->Reserve({{ info.dims|length }});
    {% for extent in info.dims %}
    entry->add_dims({{ extent }});
    {% endfor %}
  }
  {% endfor %}
}
{% endif %}

{% if net.output_info | length > 0 %}
// Appends one OutputInfo to net_def for every output description of the
// model, copying its name, data type and dimension list. Only emitted when
// the model declares at least one output.
void CreateOutputInfo(NetDef *net_def) {
  net_def->mutable_output_info()->Reserve({{ net.output_info | length }});
  {% for info in net.output_info %}
  {
    OutputInfo *entry = net_def->add_output_info();
    entry->set_name({{ info.name|tojson }});
    entry->set_data_type(static_cast<DataType>({{ info.data_type }}));
    entry->mutable_dims()->Reserve({{ info.dims|length }});
    {% for extent in info.dims %}
    entry->add_dims({{ extent }});
    {% endfor %}
  }
  {% endfor %}
}
{% endif %}

// Populates net_def's operator list: for each operator of the model, in
// index order, a new OperatorDef is added and handed to the matching
// numbered CreateOperator builder.
void CreateOperators(NetDef *net_def) {
  MACE_LATENCY_LOGGER(1, "Create operators");

  net_def->mutable_op()->Reserve({{ net.op|length }});
  {% for op_def in net.op %}
  mace::{{ tag }}::CreateOperator{{ loop.index0 }}(net_def->add_op());
  {% endfor %}
}

// Populates net_def's tensor list: for each tensor of the model, in index
// order, a new ConstTensor is added and handed to the matching numbered
// CreateTensor builder.
void CreateTensors(NetDef *net_def) {
  MACE_LATENCY_LOGGER(1, "Create tensors");

  net_def->mutable_tensors()->Reserve({{ net.tensors|length }});
  {% for tensor in net.tensors %}
  mace::{{ tag }}::CreateTensor{{ loop.index0 }}(net_def->add_tensors());
  {% endfor %}
}

{% if net.mem_arena.mem_block|length != 0 %}
// Appends one MemoryBlock to mem_arena for every memory block of the model,
// copying its mem_id, device type, memory type and x / y values. Only
// emitted when the model's memory arena has at least one block.
void CreateMemoryArena(mace::MemoryArena *mem_arena) {
  mem_arena->mutable_mem_block()->Reserve({{ net.mem_arena.mem_block|length }});
  {% for block in net.mem_arena.mem_block %}

  {
    mace::MemoryBlock *entry = mem_arena->add_mem_block();
    entry->set_mem_id({{ block.mem_id }});
    entry->set_device_type(static_cast<DeviceType>({{ block.device_type }}));
    entry->set_mem_type(static_cast<MemoryType>({{ block.mem_type }}));
    entry->set_x({{ block.x }});
    entry->set_y({{ block.y }});
  }

  {% endfor %}
}
{% endif %}

}  // namespace

namespace {{tag}} {

// Builds the complete NetDef for the model this file was generated from.
//
// The name and version are set first, then the operators and tensors are
// created, and finally each optional section (network arguments, memory
// arena, input info, output info) is filled in if the model has any entries
// for it.
//
// Returns a shared pointer owning the newly created NetDef.
const std::shared_ptr<NetDef> CreateNet() {
  // The model name and version are emitted as escaped string literals, the
  // same way every other string in this file is, so that quotes or
  // backslashes in them cannot break the generated source.
  MACE_LATENCY_LOGGER(1, {{ ("Create net " ~ net.name)|tojson }});

  auto net_def = std::make_shared<NetDef>();
  net_def->set_name({{ net.name|tojson }});
  net_def->set_version({{ net.version|string|tojson }});

  CreateOperators(net_def.get());
  CreateTensors(net_def.get());

  {% if net.arg|length != 0 %}
  CreateNetArg(net_def.get());
  {% endif %}
  {% if net.mem_arena.mem_block|length != 0 %}
  CreateMemoryArena(net_def->mutable_mem_arena());
  {% endif %}
  {% if net.input_info | length > 0 %}
  CreateInputInfo(net_def.get());
  {% endif %}
  {% if net.output_info | length > 0 %}
  CreateOutputInfo(net_def.get());
  {% endif %}

  return net_def;
}

// Returns the tag this file was generated with, which is also the name of
// the enclosing namespace.
const std::string ModelName() {
  return std::string({{ tag|tojson }});
}

// Returns the checksum string that was supplied when this file was
// generated.
const std::string ModelChecksum() {
  return std::string({{ checksum|tojson }});
}

// Returns the build-time string that was supplied when this file was
// generated.
const std::string ModelBuildTime() {
  return std::string({{ build_time|tojson }});
}

// Returns a one-line summary of the options this file was generated with:
// runtime, obfuscate, embed_model_data and winograd.
const std::string ModelBuildOptions() {
  {% set build_options = "runtime: {}, obfuscate: {}, embed_model_data: {}, winograd: {}".format(runtime, obfuscate, embed_model_data, winograd_conv) %}
  return {{ build_options|tojson }};
}

}  // namespace {{tag}}
}  // namespace mace
